
package com.shiku.imserver.common.http;


import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.core.ChannelContext;
import org.tio.core.Tio;
import org.tio.core.exception.AioDecodeException;
import org.tio.http.common.HttpConfig;
import org.tio.http.common.HttpConst;
import org.tio.http.common.Method;
import org.tio.http.common.RequestLine;
import org.tio.http.common.utils.HttpParseUtils;
import org.tio.utils.hutool.StrUtil;


public class HttpRequestDecoder {
    public enum Step {
        firstline, header, body;

    }


    private static Logger log = LoggerFactory.getLogger(HttpRequestDecoder.class);


    public static final int MAX_LENGTH_OF_HEADER = 20480;


    public static final int MAX_LENGTH_OF_HEADERLINE = 20480;


    public static final int MAX_LENGTH_OF_REQUESTLINE = 20480;


    public static HttpRequest decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext channelContext, HttpConfig httpConfig) throws AioDecodeException {

        Map<String, String> headers = new HashMap<>();

        int contentLength = 0;

        byte[] bodyBytes = null;

        StringBuilder headerSb = null;

        RequestLine firstLine = null;

        boolean appendRequestHeaderString = httpConfig.isAppendRequestHeaderString();


        if (appendRequestHeaderString) {

            headerSb = new StringBuilder(512);

        }


        firstLine = parseRequestLine(buffer, channelContext);

        if (firstLine == null) {

            return null;

        }


        HttpRequest httpRequest = new HttpRequest(channelContext.getClientNode());

        httpRequest.setRequestLine(firstLine);

        httpRequest.setChannelContext(channelContext);

        httpRequest.setHttpConfig(httpConfig);


        boolean headerCompleted = parseHeaderLine(buffer, headers, 0, httpConfig);

        if (!headerCompleted) {

            return null;

        }

        String contentLengthStr = headers.get("content-length");


        if (StrUtil.isBlank(contentLengthStr)) {

            contentLength = 0;

        } else {

            contentLength = Integer.parseInt(contentLengthStr);

            if (contentLength > httpConfig.getMaxLengthOfPostBody()) {

                throw new AioDecodeException("post body length is too big[" + contentLength + "], max length is " + httpConfig.getMaxLengthOfPostBody() + " byte");

            }

        }


        int headerLength = buffer.position() - position;

        int allNeedLength = headerLength + contentLength;


        if (readableLength < allNeedLength) {

            channelContext.setPacketNeededLength(Integer.valueOf(allNeedLength));

            return null;

        }


        if (httpConfig.checkHost &&
                !headers.containsKey("host")) {

            throw new AioDecodeException("there is no host header");

        }


        if (appendRequestHeaderString) {

            httpRequest.setHeaderString(headerSb.toString());

        } else {

            httpRequest.setHeaderString("");

        }


        httpRequest.setHeaders(headers);

        if (Tio.IpBlacklist.isInBlacklist(channelContext.groupContext, httpRequest.getClientIp())) {

            throw new AioDecodeException("[" + httpRequest.getClientIp() + "] in black list");

        }


        httpRequest.setContentLength(contentLength);


        String connection = headers.get("connection");

        if (connection != null) {

            httpRequest.setConnection(connection.toLowerCase());

        }


        if (StrUtil.isNotBlank(firstLine.queryString)) {

            decodeParams(httpRequest.getParams(), firstLine.queryString, httpRequest.getCharset(), channelContext);

        }


        if (contentLength != 0) {


            bodyBytes = new byte[contentLength];

            buffer.get(bodyBytes);

            httpRequest.setBody(bodyBytes);


            parseBody(httpRequest, firstLine, bodyBytes, channelContext, httpConfig);

        }


        return httpRequest;

    }


    public static void decodeParams(Map<String, Object[]> params, String queryString, String charset, ChannelContext channelContext) {

        if (StrUtil.isBlank(queryString)) {

            return;

        }


        String[] keyvalues = queryString.split("&");

        for (String keyvalue : keyvalues) {

            String[] keyvalueArr = keyvalue.split("=");

            if (keyvalueArr.length == 2) {


                String key = keyvalueArr[0];

                String value = null;

                try {

                    value = URLDecoder.decode(keyvalueArr[1], charset);

                } catch (UnsupportedEncodingException e) {

                    log.error(channelContext.toString(), e);

                }


                Object[] existValue = params.get(key);

                if (existValue != null) {

                    String[] newExistValue = new String[existValue.length + 1];

                    System.arraycopy(existValue, 0, newExistValue, 0, existValue.length);

                    newExistValue[newExistValue.length - 1] = value;

                    params.put(key, newExistValue);

                } else {

                    String[] newExistValue = {value};

                    params.put(key, newExistValue);

                }

            }

        }

    }


    private static void parseBody(HttpRequest httpRequest, RequestLine firstLine, byte[] bodyBytes, ChannelContext channelContext, HttpConfig httpConfig) throws AioDecodeException {

        String contentType, initboundary;

        parseBodyFormat(httpRequest, httpRequest.getHeaders());

        HttpConst.RequestBodyFormat bodyFormat = httpRequest.getBodyFormat();


        httpRequest.setBody(bodyBytes);


        switch (bodyFormat) {

            case MULTIPART:

                if (log.isInfoEnabled()) {

                    String str = null;

                    if (bodyBytes != null && bodyBytes.length > 0 &&
                            log.isDebugEnabled()) {

                        try {

                            str = new String(bodyBytes, httpRequest.getCharset());

                            log.debug("{} multipart body value\r\n{}", channelContext, str);

                        } catch (UnsupportedEncodingException e) {

                            log.error(channelContext.toString(), e);

                        }

                    }

                }


                contentType = httpRequest.getHeader("content-type");

                initboundary = HttpParseUtils.getSubAttribute(contentType, "boundary");

                log.debug("{}, initboundary:{}", channelContext, initboundary);

                HttpMultiBodyDecoder.decode(httpRequest, firstLine, bodyBytes, initboundary, channelContext, httpConfig);

                return;

        }


        String bodyString = null;

        if (bodyBytes != null && bodyBytes.length > 0) {

            try {

                bodyString = new String(bodyBytes, httpRequest.getCharset());

                httpRequest.setBodyString(bodyString);

                if (log.isInfoEnabled()) {

                    log.info("{} body value\r\n{}", channelContext, bodyString);

                }

            } catch (UnsupportedEncodingException e) {

                log.error(channelContext.toString(), e);

            }

        }


        if (bodyFormat == HttpConst.RequestBodyFormat.URLENCODED) {

            parseUrlencoded(httpRequest, firstLine, bodyBytes, bodyString, channelContext);

        }

    }


    public static void parseBodyFormat(HttpRequest httpRequest, Map<String, String> headers) {

        String contentType = headers.get("content-type");

        String Content_Type = null;

        if (contentType != null) {

            Content_Type = contentType.toLowerCase();

        }


        if (Content_Type.startsWith("text/plain")) {

            httpRequest.setBodyFormat(HttpConst.RequestBodyFormat.TEXT);

        } else if (Content_Type.startsWith("multipart/form-data")) {

            httpRequest.setBodyFormat(HttpConst.RequestBodyFormat.MULTIPART);

        } else {

            httpRequest.setBodyFormat(HttpConst.RequestBodyFormat.URLENCODED);

        }


        if (StrUtil.isNotBlank(Content_Type)) {

            String charset = HttpParseUtils.getSubAttribute(Content_Type, "charset");

            if (StrUtil.isNotBlank(charset)) {

                httpRequest.setCharset(charset);

            } else {

                httpRequest.setCharset("utf-8");

            }

        }

    }


    public static boolean parseHeaderLine(ByteBuffer buffer, Map<String, String> headers, int hasReceivedHeaderLength, HttpConfig httpConfig) throws AioDecodeException {

        if (!buffer.hasArray()) {

            return parseHeaderLine2(buffer, headers, hasReceivedHeaderLength, httpConfig);

        }


        byte[] allbs = buffer.array();

        int initPosition = buffer.position();

        int lastPosition = initPosition;

        int remaining = buffer.remaining();

        if (remaining == 0) {
            return false;
        }

        if (remaining > 1) {

            byte b1 = buffer.get();

            byte b2 = buffer.get();

            if (13 == b1 && 10 == b2) {
                return true;
            }

            if (10 == b1) {

                return true;

            }

        } else if (10 == buffer.get()) {

            return true;

        }


        String name = null;

        String value = null;

        boolean hasValue = false;


        boolean needIteration = false;

        while (buffer.hasRemaining()) {

            byte b = buffer.get();

            if (name == null) {

                if (b == 58) {

                    int len = buffer.position() - lastPosition - 1;

                    name = new String(allbs, lastPosition, len);

                    lastPosition = buffer.position();
                    continue;

                }
                if (b == 10) {

                    byte lastByte = buffer.get(buffer.position() - 2);

                    int len = buffer.position() - lastPosition - 1;

                    if (lastByte == 13) {

                        len = buffer.position() - lastPosition - 2;

                    }

                    name = new String(allbs, lastPosition, len);

                    lastPosition = buffer.position();

                    headers.put(name.toLowerCase(), "");


                    needIteration = true;
                    break;

                }

                continue;

            }

            if (value == null) {

                if (b == 10) {

                    byte lastByte = buffer.get(buffer.position() - 2);

                    int len = buffer.position() - lastPosition - 1;

                    if (lastByte == 13) {

                        len = buffer.position() - lastPosition - 2;

                    }

                    value = new String(allbs, lastPosition, len);

                    lastPosition = buffer.position();


                    headers.put(name.toLowerCase(), StrUtil.trimEnd(value));

                    needIteration = true;

                    break;

                }

                if (!hasValue && b == 32) {

                    lastPosition = buffer.position();
                    continue;

                }

                hasValue = true;

            }

        }


        int lineLength = buffer.position() - initPosition;


        if (lineLength > 20480) {

            throw new AioDecodeException("header line is too long, max length of header line is 20480");

        }


        if (needIteration) {

            int headerLength = lineLength + hasReceivedHeaderLength;


            if (headerLength > 20480) {

                throw new AioDecodeException("header is too long, max length of header is 20480");

            }

            return parseHeaderLine(buffer, headers, headerLength, httpConfig);

        }


        return false;

    }


    private static boolean parseHeaderLine2(ByteBuffer buffer, Map<String, String> headers, int headerLength, HttpConfig httpConfig) throws AioDecodeException {

        int initPosition = buffer.position();

        int lastPosition = initPosition;

        int remaining = buffer.remaining();

        if (remaining == 0) {
            return false;
        }

        if (remaining > 1) {

            byte b1 = buffer.get();

            byte b2 = buffer.get();

            if (13 == b1 && 10 == b2) {
                return true;
            }

            if (10 == b1) {

                return true;

            }

        } else if (10 == buffer.get()) {

            return true;

        }


        String name = null;

        String value = null;

        boolean hasValue = false;


        boolean needIteration = false;

        while (buffer.hasRemaining()) {

            byte b = buffer.get();

            if (name == null) {

                if (b == 58) {

                    int nowPosition = buffer.position();

                    byte[] bs = new byte[nowPosition - lastPosition - 1];

                    buffer.position(lastPosition);

                    buffer.get(bs);

                    name = new String(bs);

                    lastPosition = nowPosition;

                    buffer.position(nowPosition);
                    continue;

                }
                if (b == 10) {

                    int nowPosition = buffer.position();

                    byte[] bs = null;

                    byte lastByte = buffer.get(nowPosition - 2);


                    if (lastByte == 13) {

                        bs = new byte[nowPosition - lastPosition - 2];

                    } else {

                        bs = new byte[nowPosition - lastPosition - 1];

                    }


                    buffer.position(lastPosition);

                    buffer.get(bs);

                    name = new String(bs);

                    lastPosition = nowPosition;

                    buffer.position(nowPosition);


                    headers.put(name.toLowerCase(), null);

                    needIteration = true;

                    break;

                }

                continue;

            }

            if (value == null) {

                if (b == 10) {

                    int nowPosition = buffer.position();

                    byte[] bs = null;

                    byte lastByte = buffer.get(nowPosition - 2);


                    if (lastByte == 13) {

                        bs = new byte[nowPosition - lastPosition - 2];

                    } else {

                        bs = new byte[nowPosition - lastPosition - 1];

                    }


                    buffer.position(lastPosition);

                    buffer.get(bs);

                    value = new String(bs);

                    lastPosition = nowPosition;

                    buffer.position(nowPosition);


                    headers.put(name.toLowerCase(), StrUtil.trimEnd(value));

                    needIteration = true;


                    break;

                }

                if (!hasValue && b == 32) {

                    lastPosition = buffer.position();
                    continue;

                }

                hasValue = true;

            }

        }


        if (needIteration) {

            int myHeaderLength = buffer.position() - initPosition;

            if (myHeaderLength > 20480) {

                throw new AioDecodeException("header is too long");

            }

            return parseHeaderLine(buffer, headers, myHeaderLength + headerLength, httpConfig);

        }


        if (remaining > 20480) {

            throw new AioDecodeException("header line is too long");

        }

        return false;

    }


    public static RequestLine parseRequestLine(ByteBuffer buffer, ChannelContext channelContext) throws AioDecodeException {

        if (!buffer.hasArray()) {

            return parseRequestLine2(buffer, channelContext);

        }


        byte[] allbs = buffer.array();


        int initPosition = buffer.position();


        String methodStr = null;

        String pathStr = null;

        String queryStr = null;

        String protocol = null;

        String version = null;

        int lastPosition = initPosition;

        while (buffer.hasRemaining()) {

            byte b = buffer.get();

            if (methodStr == null) {

                if (b == 32) {

                    int len = buffer.position() - lastPosition - 1;

                    methodStr = new String(allbs, lastPosition, len);

                    lastPosition = buffer.position();

                }
                continue;

            }

            if (pathStr == null) {

                if (b == 32 || b == 63) {

                    int len = buffer.position() - lastPosition - 1;

                    pathStr = new String(allbs, lastPosition, len);

                    lastPosition = buffer.position();


                    if (b == 32) {
                        queryStr = "";
                    }

                }

                continue;

            }

            if (queryStr == null) {

                if (b == 32) {

                    int len = buffer.position() - lastPosition - 1;

                    queryStr = new String(allbs, lastPosition, len);

                    lastPosition = buffer.position();

                }
                continue;

            }

            if (protocol == null) {

                if (b == 47) {

                    int len = buffer.position() - lastPosition - 1;

                    protocol = new String(allbs, lastPosition, len);

                    lastPosition = buffer.position();

                }
                continue;

            }

            if (version == null &&
                    b == 10) {

                byte lastByte = buffer.get(buffer.position() - 2);

                int len = buffer.position() - lastPosition - 1;

                if (lastByte == 13) {

                    len = buffer.position() - lastPosition - 2;

                }

                version = new String(allbs, lastPosition, len);

                lastPosition = buffer.position();


                RequestLine requestLine = new RequestLine();

                Method method = Method.from(methodStr);

                requestLine.setMethod(method);

                requestLine.setPath(pathStr);

                requestLine.setInitPath(pathStr);

                requestLine.setQueryString(queryStr);

                requestLine.setProtocol(protocol);

                requestLine.setVersion(version);


                return requestLine;

            }

        }


        if (buffer.position() - initPosition > 20480) {

            throw new AioDecodeException("request line is too long");

        }

        return null;

    }


    private static RequestLine parseRequestLine2(ByteBuffer buffer, ChannelContext channelContext) throws AioDecodeException {

        int initPosition = buffer.position();


        String methodStr = null;

        String pathStr = null;

        String queryStr = null;

        String protocol = null;

        String version = null;

        int lastPosition = initPosition;

        while (buffer.hasRemaining()) {

            byte b = buffer.get();

            if (methodStr == null) {

                if (b == 32) {

                    int nowPosition = buffer.position();

                    byte[] bs = new byte[nowPosition - lastPosition - 1];

                    buffer.position(lastPosition);

                    buffer.get(bs);

                    methodStr = new String(bs);

                    lastPosition = nowPosition;

                    buffer.position(nowPosition);

                }
                continue;

            }

            if (pathStr == null) {

                if (b == 32 || b == 63) {

                    int nowPosition = buffer.position();

                    byte[] bs = new byte[nowPosition - lastPosition - 1];

                    buffer.position(lastPosition);

                    buffer.get(bs);

                    pathStr = new String(bs);

                    lastPosition = nowPosition;

                    buffer.position(nowPosition);


                    if (b == 32) {
                        queryStr = "";
                    }

                }

                continue;

            }

            if (queryStr == null) {

                if (b == 32) {

                    int nowPosition = buffer.position();

                    byte[] bs = new byte[nowPosition - lastPosition - 1];

                    buffer.position(lastPosition);

                    buffer.get(bs);

                    queryStr = new String(bs);

                    lastPosition = nowPosition;

                    buffer.position(nowPosition);

                }
                continue;

            }

            if (protocol == null) {

                if (b == 47) {

                    int nowPosition = buffer.position();

                    byte[] bs = new byte[nowPosition - lastPosition - 1];

                    buffer.position(lastPosition);

                    buffer.get(bs);

                    protocol = new String(bs);

                    lastPosition = nowPosition;

                    buffer.position(nowPosition);

                }
                continue;

            }

            if (version == null &&
                    b == 10) {

                int nowPosition = buffer.position();

                byte[] bs = null;

                byte lastByte = buffer.get(nowPosition - 2);


                if (lastByte == 13) {

                    bs = new byte[nowPosition - lastPosition - 2];

                } else {

                    bs = new byte[nowPosition - lastPosition - 1];

                }


                buffer.position(lastPosition);

                buffer.get(bs);

                version = new String(bs);

                lastPosition = nowPosition;

                buffer.position(nowPosition);


                RequestLine requestLine = new RequestLine();

                Method method = Method.from(methodStr);

                requestLine.setMethod(method);

                requestLine.setPath(pathStr);

                requestLine.setInitPath(pathStr);

                requestLine.setQueryString(queryStr);

                requestLine.setProtocol(protocol);

                requestLine.setVersion(version);


                return requestLine;

            }

        }


        if (buffer.position() - initPosition > 20480) {

            throw new AioDecodeException("request line is too long");

        }

        return null;

    }


    private static void parseUrlencoded(HttpRequest httpRequest, RequestLine firstLine, byte[] bodyBytes, String bodyString, ChannelContext channelContext) {

        decodeParams(httpRequest.getParams(), bodyString, httpRequest.getCharset(), channelContext);

    }

}


